{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Approximating the **Lyapunov Spectrum** of the Lorenz system using JAX's autodiff\n",
    "\n",
    "The Lorenz equations are a prototypical example of **deterministic chaos**. They\n",
    "are a system of three **nonlinear** ODEs\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\frac{dx}{dt} &= \\sigma(y - x), \\\\\n",
    "\\frac{dy}{dt} &= x(\\rho - z) - y, \\\\\n",
    "\\frac{dz}{dt} &= xy - \\beta z.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The three variables can be combined into the state vector $u = (x, y, z) \\in\n",
    "\\mathbb{R}^3$. A system with $N$ degrees of freedom has $N$ Lyapunov exponents,\n",
    "forming the **Lyapunov Spectrum**. To assess chaotic perturbation growth,\n",
    "typically only the **largest** Lyapunov exponent $\\max(\\lambda_i)$ is relevant:\n",
    "\n",
    "$$\n",
    "\\| \\delta u(t) \\| \\approx \\| \\delta u(0) \\| \\exp(\\max(\\lambda_i) t).\n",
    "$$\n",
    "\n",
    "The entire spectrum $\\{\\lambda_i\\}$ can, e.g., be used to compute the\n",
    "Kaplan-Yorke dimension and analyze other properties of the system.\n",
    "\n",
    "In this notebook, we will approximate the spectrum for the Lorenz system under\n",
    "the original configuration $\\sigma = 10$, $\\rho = 28$ and $\\beta = 8/3$ (1)\n",
    "using a [Runge-Kutta 4\n",
    "simulator](https://github.com/Ceyron/machine-learning-and-simulation/blob/main/english/simulation_scripts/lorenz_simulator_numpy.ipynb)\n",
    "with time step size $\\Delta t = 0.01$. Let $\\mathcal{P}$ denote the discrete\n",
    "time stepper that advances from one time level $u^{[t]}$ to the next,\n",
    "$u^{[t+1]}$.\n",
    "\n",
    "Then, we can approximate the spectrum $\\{\\lambda_0, \\lambda_1, \\lambda_2\\}$ by the following strategy:\n",
    "\n",
    "1. Draw a reasonable initial condition, e.g., $u^{[0]} = (1, 1, 1)$.\n",
    "2. Evolve the initial condition until it enters the chaotic attractor, e.g., by\n",
    "   using $5000$ time steps to get $u^{[5000]}$. Use the last state $u^{[5000]}$\n",
    "   as the \"warmed-up\" initial state $u^{[0]} \\leftarrow u^{[5000]}$.\n",
    "3. Introduce a perturbation **matrix** $Y^{[0]} \\in \\mathbb{R}^{3 \\times\n",
    "   3}$:\n",
    "   1. For example, draw it randomly from a normal distribution, $Y^{[0]}_{ij}\n",
    "      \\sim \\mathcal{N}(0, 1)$.\n",
    "   2. Orthonormalize it using a **QR decomposition**:\n",
    "      1. $Q, R = \\text{QR}(Y^{[0]})$\n",
    "      2. $Y^{[0]} \\leftarrow Q$\n",
    "4. Evolve $u^{[t]}$ via the Runge-Kutta 4 stepper $u^{[t+1]} =\n",
    "   \\mathcal{P}(u^{[t]})$. At the same time, at each time step $t$:\n",
    "   1. Compute the Jacobian of the time stepper evaluated at the current state,\n",
    "      $J_\\mathcal{P}(u^{[t]})$.\n",
    "   2. Then, evolve the perturbation $Y^{[t]}$ via $Y^{[t+1]} =\n",
    "      J_\\mathcal{P}(u^{[t]}) Y^{[t]}$. (*)\n",
    "   3. Re-orthonormalize the perturbation matrix $Y^{[t+1]}$ using a QR\n",
    "      decomposition and record the diagonal of $R$:\n",
    "      1. $Q, R = \\text{QR}(Y^{[t+1]})$\n",
    "      2. $Y^{[t+1]} \\leftarrow Q$\n",
    "      3. $\\epsilon^{[t+1]} = \\text{diag}(R)$\n",
    "5. Do this for a certain number of time steps $T$, e.g., $T = 50000$, and record the\n",
    "   growth factors $\\epsilon^{[t+1]} \\in \\mathbb{R}^{3}$.\n",
    "6. Approximate the Lyapunov spectrum via\n",
    "\n",
    "$$\n",
    "\\lambda_i = \\frac{1}{\\Delta t}\\frac{1}{T} \\sum_{t=0}^{T-1} \\log |\\epsilon^{[t+1]}_i|.\n",
    "$$\n",
    "\n",
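    "(*) Instead of instantiating the full (and oftentimes dense) Jacobian matrix\n",
    "$J_\\mathcal{P}(u^{[t]})$ at each time step, we can apply `jax.vmap` to the\n",
    "Jacobian-vector product function returned by `jax.linearize`, which pushes\n",
    "each column of $Y^{[t]}$ forward without ever forming the Jacobian.\n",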
    "\n",
    "---\n",
    "\n",
    "(1) E. N. Lorenz, \"Deterministic Nonperiodic Flow\", Journal of the Atmospheric\n",
    "Sciences, 1963,\n",
    "https://journals.ametsoc.org/view/journals/atsc/20/2/1520-0469_1963_020_0130_dnf_2_0_co_2.xml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lorenz_rhs(u, *, sigma, rho, beta):\n",
    "    x, y, z = u\n",
    "    x_dot = sigma * (y - x)\n",
    "    y_dot = x * (rho - z) - y\n",
    "    z_dot = x * y - beta * z\n",
    "    u_dot = jnp.array([x_dot, y_dot, z_dot])\n",
    "    return u_dot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LorenzStepperRK4:\n",
    "    def __init__(self, dt=0.01, *, sigma=10, rho=28, beta=8/3):\n",
    "        self.dt = dt\n",
    "        self.sigma = sigma\n",
    "        self.rho = rho\n",
    "        self.beta = beta\n",
    "    \n",
    "    def __call__(self, u_prev):\n",
    "        lorenz_rhs_fixed = lambda u: lorenz_rhs(\n",
    "            u,\n",
    "            sigma=self.sigma,\n",
    "            rho=self.rho,\n",
    "            beta=self.beta,\n",
    "        )\n",
    "        k_1 = lorenz_rhs_fixed(u_prev)\n",
    "        k_2 = lorenz_rhs_fixed(u_prev + 0.5 * self.dt * k_1)\n",
    "        k_3 = lorenz_rhs_fixed(u_prev + 0.5 * self.dt * k_2)\n",
    "        k_4 = lorenz_rhs_fixed(u_prev + self.dt * k_3)\n",
    "        u_next = u_prev + self.dt * (k_1 + 2*k_2 + 2*k_3 + k_4)/6\n",
    "        return u_next"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "lorenz_stepper = LorenzStepperRK4()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "u_0 = jnp.array([1.0, 1.0, 1.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1.0125672, 1.2599177, 0.984891 ], dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lorenz_stepper(u_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(stepper, n, *, include_init: bool = False):\n",
    "    def scan_fn(u, _):\n",
    "        u_next = stepper(u)\n",
    "        return u_next, u_next\n",
    "\n",
    "    def rollout_fn(u_0):\n",
    "        _, trj = jax.lax.scan(scan_fn, u_0, None, length=n)\n",
    "\n",
    "        if include_init:\n",
    "            return jnp.concatenate([jnp.expand_dims(u_0, axis=0), trj], axis=0)\n",
    "\n",
    "        return trj\n",
    "\n",
    "    return rollout_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "trj = rollout(lorenz_stepper, 5000, include_init=True)(u_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "u_warmed = trj[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def push_orthonormal_matrix_variation(stepper, u_0, Y_0, n):\n",
    "    def scan_fn(carry, _):\n",
    "        u, Y = carry\n",
    "\n",
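    "        # Option 1: instantiate the full Jacobian (kept for reference)\n",
    "        # u_next = stepper(u)\n",
    "        # jac = jax.jacfwd(stepper)(u)\n",
    "        # Y_next = jac @ Y\n",
    "\n",
    "        # Option 2 (more efficient): linearize once around u, then push every\n",
    "        # column of Y through the resulting Jacobian-vector product function\n",
    "        u_next, jvp_fn = jax.linearize(stepper, u)\n",
    "        Y_next = jax.vmap(jvp_fn, in_axes=-1, out_axes=-1)(Y)\n",
    "\n",
    "        # Re-orthonormalize; the diagonal of R holds the per-direction growth\n",
    "        Q, R = jnp.linalg.qr(Y_next)\n",
    "        Y_next = Q\n",
    "        growth = jnp.diag(R)\n",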
    "\n",
    "        carry_next = (u_next, Y_next)\n",
    "\n",
    "        return carry_next, growth\n",
    "\n",
    "    # Orthonormalize the initial perturbation matrix\n",
    "    Q, _ = jnp.linalg.qr(Y_0)\n",
    "\n",
    "    initial_carry = (u_0, Q)\n",
    "\n",
    "    _, growth_trj = jax.lax.scan(\n",
    "        scan_fn,\n",
    "        initial_carry,\n",
    "        None,\n",
    "        length=n,\n",
    "    )\n",
    "\n",
    "    return growth_trj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[ 1.6226422 ,  2.0252647 , -0.43359444],\n",
       "       [-0.07861735,  0.1760909 , -0.97208923],\n",
       "       [-0.49529874,  0.4943786 ,  0.6643493 ]], dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_0 = jax.random.normal(jax.random.key(0), (3, 3))\n",
    "Y_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "growth_trj = push_orthonormal_matrix_variation(\n",
    "    lorenz_stepper,\n",
    "    u_warmed,\n",
    "    Y_0,\n",
    "    50_000\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 3)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "growth_trj.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3,)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unscaled_lyapunov_spectrum = jnp.mean(\n",
    "    jnp.log(jnp.abs(growth_trj)),\n",
    "    axis=0,\n",
    ")\n",
    "unscaled_lyapunov_spectrum.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 9.039051e-03, -3.169071e-05, -1.456730e-01], dtype=float32)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unscaled_lyapunov_spectrum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "lyapunov_spectrum = unscaled_lyapunov_spectrum / lorenz_stepper.dt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 9.0400004e-01, -3.0000000e-03, -1.4567000e+01], dtype=float32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lyapunov_spectrum.round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "youtube",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
